#!/usr/bin/env python

import multiprocessing as mp
import numpy as np
import os
import time

from math import pi

import ndsimulator.onerun.minimize as mz
from ndsimulator.pylotrun import PylotRun
from ndsimulator.replica.commitor import Commitor_Replica
from ndsimulator.control import Control


def main():
    """Run a scan of 1-D umbrella-sampling windows using the ``tfmodel`` colvar.

    Steps:

    1. Load candidate starting configurations from ``init_config.dat``.
       Column 0 of each row is matched against the window centres
       (presumably the configuration's colvar value); the remaining columns
       are used as the starting coordinates.
    2. Place window centres ``r`` at -24, -23, ..., 19 (step ``gridz``).
       Every window gets the same bias strength ``k``.
    3. For each window, pick the stored configuration whose column-0 value
       is closest to the window centre.
    4. Run every window through ``umb``, either serially or in a
       ``multiprocessing.Pool`` when ``ncpus > 1``. Progress is echoed to
       stdout and to the line-buffered log file ``log.umb``.

    The directory named by ``shared_input["workdir"]`` is created if missing.
    The total wall-clock time is printed at the end.
    """
    ncpus = 1
    modeldir = "../../../Models/mueller_brown_label"

    # ndmin=2 keeps the column indexing below valid even when the file holds
    # a single configuration (loadtxt would otherwise return a 1-D array).
    init_list = np.loadtxt("init_config.dat", ndmin=2)
    allz = init_list[:, 0]

    gridz = 1.0
    r = np.arange(-24, 20, gridz)  # window centres: -24, -23, ..., 19
    k = np.ones(len(r)) * 0.1  # same bias strength for every window
    case = len(r)

    # Starting coordinates for each window: the stored configuration whose
    # column-0 value lies nearest to that window's centre.
    x = [init_list[np.argmin(np.abs(center - allz)), 1:] for center in r]

    # Settings shared by all windows. "filetitle", "umb_k", "umb_r0" and
    # "x0" are only placeholders: umb() overrides them for every window.
    shared_input = {
        "filetitle": "0",
        "workdir": "umb",
        "method": "md",
        "biases": ["umb"],
        "umb_n": 1,
        "umb_k": [0.125],
        "umb_r0": [[0]],
        "colvardim": 1,
        "colvarfunc": "tfmodel",
        "tf_folder": modeldir,
        "tf_inputname": "X:0",
        "tf_outputname": "z_x:0",
        "tf_gradname": "dz_dx:0",
        "potfunc": "5dmueller",
        "true_colvarfunc": "5dto2d",
        "integrate": "2nd-langevin",
        "md_gamma": 0.01,
        "md_fixdt": True,
        "steps": 10000000,
        "dt": 1.0,
        "temp": 300.0,
        "mass": 1.0,
        "ndim": 5,
        "track_pvf": False,
        "x0": [18.0, 0.0, 18.0, 0.0, 100.0],
        "stat_freq": 1000,
        "dump": True,
        "dump_freq": 1000,
        "movie": False,
        "screen": False,
        "oneplot": False,
    }

    t0 = time.time()
    # exist_ok avoids the race between checking for the directory and
    # creating it.
    os.makedirs(shared_input["workdir"], exist_ok=True)
    with open("log.umb", "w+", 1) as fout:

        def announce(i):
            # Write the same progress line to the console and the log file.
            print("compute", i, r[i], x[i], k[i])
            print("compute", i, r[i], x[i], k[i], file=fout)

        if ncpus > 1:
            with mp.Pool(ncpus) as pool:
                results = []
                for i in range(case):
                    announce(i)
                    results.append(
                        pool.apply_async(umb, args=(i, x[i], r[i], k[i], shared_input))
                    )
                # Wait for every window before the pool is terminated on
                # leaving the `with`. get() also re-raises any exception
                # that was raised inside a worker.
                for res in results:
                    res.get()
        else:
            for i in range(case):
                announce(i)
                umb(i, x[i], r[i], k[i], shared_input)

    print("timer", time.time() - t0)


def umb(idx, x0, R0, K, sinput):
    """Run a single umbrella-sampling window to completion.

    Parameters
    ----------
    idx : int
        Window index. It becomes the run's ``filetitle`` and also determines
        the RNG seed, so each window is reproducible.
    x0 : array-like
        Starting coordinates for this window.
    R0 : float
        Umbrella centre for this window (stored as ``umb_r0 = [[R0]]``).
    K : float
        Bias strength for this window (stored as ``umb_k = [K]``).
    sinput : dict
        Settings shared by all windows. They are copied into a fresh
        ``Control`` before the per-window values above are applied.
    """
    cont = Control()
    cont.copy_args(sinput)

    # Per-window values layered on top of the shared settings.
    per_window = {
        "filetitle": f"{idx}",
        "umb_r0": [[R0]],
        "umb_k": [K],
        "x0": x0,
    }
    for attr_name, attr_value in per_window.items():
        setattr(cont, attr_name, attr_value)

    # Deterministic, window-specific seed.
    rng = np.random.RandomState(1034687 + 7 * idx)

    sim = PylotRun(cont=cont, random=rng)
    sim.begin()
    sim.run()
    sim.end()


if __name__ == "__main__":
    # Entry-point guard. main() may start a multiprocessing.Pool, and with
    # the "spawn"/"forkserver" start methods the workers re-import this
    # module. The guard keeps them from launching the whole scan again.
    main()
